Max pool #163
Conversation
for more information, see https://pre-commit.ci
I'm just changing the executor implementation. Hence I don't think I need to add extra tests apart from what's already in CI.
Having an explicit op instead of decomposing feels reasonable.
Do we want the subsymbols of the poolXd to be the original verbose decomposition? My gut is actually 'no', i.e. that just about every backend would explicitly implement a pooling operator anyway. But wanted to throw it out there.
In general I'd recommend more """doc comments""" on functions, but I'm not going to hold off on a +1 over that. A comment I'd like to see somewhere is something to the effect of "we tried decomposing this as conv + X + Y + Z, but it leads to really bad perf, and systems like nvFuser implement pooling directly anyway", i.e. explaining why this op exists and why the alternative isn't great.
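To make the discussion concrete, here is an illustrative pure-Python sketch (not the PR's actual code) of what an explicit `max_pool2d` operator computes: a windowed max over the input. Backends that implement pooling directly compute exactly this in one fused kernel, rather than building it out of convolution plus extra ops.

```python
# Illustrative sketch: semantics of an explicit max_pool2d operator
# on a plain H x W list-of-lists input. This is a reference for the
# discussion, not the PR's implementation.
def max_pool2d(x, kernel, stride=None):
    """Windowed max over a 2D input; stride defaults to the kernel size."""
    kh, kw = kernel
    sh, sw = stride if stride is not None else kernel
    h, w = len(x), len(x[0])
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    return [
        [
            max(
                x[i * sh + di][j * sw + dj]
                for di in range(kh)
                for dj in range(kw)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]

x = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]
print(max_pool2d(x, (2, 2)))  # [[6, 8], [14, 16]]
```

A decomposed path would reach the same result via several intermediate tensors (e.g. an unfold/conv step followed by a reduction), which is where the performance cost comes from.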
Thank you @jjsjann123 @tfogal
What does this PR do?
Fixes #164.
We have restored Thunder performance by having `torchex` run `max_pool2d/3d` via a single ATen call, instead of the decomposed primitive operations built on convolution. A quick performance comparison is demonstrated below:
This is before the PR:
After the PR:
Note this is only done for max_pool2d/3d, because max_pool1d is implicitly differentiable in PyTorch, so there is no backward entry in ATen.
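The executor-mapping idea behind the PR can be sketched as follows. This is a hypothetical, simplified registry, not Thunder's actual API: the point is that when an executor registers a direct implementation for a symbol, the trace executes one fused call, and only falls back to the slow decomposition otherwise.

```python
# Hypothetical sketch of the executor-mapping idea; names and the
# registry shape are illustrative, not Thunder's real interfaces.
DIRECT_IMPLS = {}

def register_direct(name, fn):
    """Register a single-call implementation for a symbol name."""
    DIRECT_IMPLS[name] = fn

def execute(name, *args, decomposition=None, **kwargs):
    # Prefer the direct (fused) call when the executor provides one;
    # otherwise fall back to the decomposed path.
    impl = DIRECT_IMPLS.get(name, decomposition)
    if impl is None:
        raise NotImplementedError(f"no implementation for {name}")
    return impl(*args, **kwargs)

# Stand-in "direct" implementation for demonstration purposes.
register_direct("max_pool2d", lambda x: max(max(row) for row in x))
print(execute("max_pool2d", [[1, 2], [3, 4]]))  # 4
```

In the PR itself, the direct implementation is the ATen `max_pool2d/3d` entry, which also carries the matching backward, avoiding the convolution-based decomposition entirely.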